"""
Command‑line entry point to run the Kernel→Metric simulation on multiple
conditions.  This script reads a YAML configuration file describing
lattice sizes, gauge groups, smoothing scales, and other parameters.
For each combination it builds the kernel envelope, solves the
Poisson equation, computes lensing, performs control experiments, and
writes results to disk.

Usage::

    python -m k2m.run_all --config configs/anchors.yaml --output runs

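The results for each condition are written as JSON (metadata) and NPZ
(arrays) into the specified output directory.  Each JSON records SHA‑256
hashes of the kernel array, the pivot parameters and the configuration;
a separate ``MANIFEST.json`` records hashes of the input data files and
the YAML configuration together with the git commit of the code, for
reproducibility.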
"""

from __future__ import annotations

import argparse
import itertools
import json
import os
from datetime import datetime, timezone
from pathlib import Path

import numpy as np
import yaml

from . import envelope
from . import poisson
from . import optics
from . import analysis
from . import io_fphs


def _compute_config_hash(config: dict) -> str:
    """Return the hexadecimal SHA‑256 digest of *config*.

    The dictionary is rendered as JSON with sorted keys, so two
    dictionaries with equal contents hash identically regardless of key
    insertion order.  The JSON text is UTF‑8 encoded before hashing.
    """
    from hashlib import sha256

    canonical_text = json.dumps(config, sort_keys=True)
    hasher = sha256(canonical_text.encode("utf-8"))
    return hasher.hexdigest()


def _run_control(E_variant: np.ndarray, lambd: float, ell: int, L: int) -> dict:
    """Fit radial slopes for one control envelope and return them as floats.

    Applies ``envelope.smooth_and_gradient`` → ``normalise_gradient_magnitude``
    → ``poisson.compute_potential`` → radial profiles of the potential and of
    the field magnitude → ``analysis.select_fit_window`` on each profile.

    Returns a dict ``{"sV", "sE", "R2s": {"V", "E"}}`` with the fitted slopes
    of the potential and field profiles and their R² values.
    """
    _smoothed, grad_x, grad_y = envelope.smooth_and_gradient(E_variant, ell)
    G_hat_ctrl = envelope.normalise_gradient_magnitude(grad_x, grad_y)
    V_ctrl, Ex_ctrl, Ey_ctrl = poisson.compute_potential(G_hat_ctrl, lambd, ell)
    radii_V, V_radial, _ = analysis.radial_profile(V_ctrl)
    radii_E, E_mag_radial, _ = analysis.radial_profile(np.sqrt(Ex_ctrl**2 + Ey_ctrl**2))
    _, _, slope_V, _, R2_V = analysis.select_fit_window(radii_V, V_radial, ell=ell, L=L)
    _, _, slope_E, _, R2_E = analysis.select_fit_window(radii_E, E_mag_radial, ell=ell, L=L)
    return {
        "sV": float(slope_V),
        "sE": float(slope_E),
        "R2s": {"V": float(R2_V), "E": float(R2_E)},
    }


def run_condition(
    gauge: str,
    L: int,
    ell: int,
    lambd: float,
    b: float,
    kappa: float,
    seed: int,
    data_dir: str,
    output_dir: str,
    decay_alpha: float | None = None,
) -> None:
    """Run the kernel→metric simulation for a single condition.

    Builds the envelope, solves the Poisson equation, fits radial slopes of
    the potential and field magnitude, computes the lensing deflection curve
    and runs two control experiments (shuffled and uniform envelopes).
    Intermediate arrays are saved to ``<gauge>_L<L>_ell<ell>.npz`` and summary
    statistics to ``<gauge>_L<L>_ell<ell>.json`` inside *output_dir*, which is
    created if it does not exist.

    Parameters
    ----------
    gauge:
        Gauge group name, matched case-insensitively against ``"SU2"`` and
        ``"SU3"``.
    L:
        Lattice size, forwarded to envelope construction and fit-window
        selection.
    ell:
        Smoothing width.
    lambd:
        Passed to ``poisson.compute_potential`` and the deflection curve;
        recorded as ``"lambda"``.
    b, kappa:
        Recorded in the JSON metadata only; not used in any computation here.
    seed:
        Seeds both the bootstrap slope errors and the shuffle control.
    data_dir:
        Input data directory; ``pivot_params.json`` in it is hashed.
    output_dir:
        Directory for the JSON/NPZ outputs.
    decay_alpha:
        Optional value forwarded to ``envelope.build_envelope``.

    Raises
    ------
    ValueError
        If *gauge* is not a known gauge group.
    """
    # Derive gauge factor (allow amplitude differences between gauge groups).
    gauge_key = gauge.upper()
    if gauge_key == "SU2":
        # Scale SU(2) envelope to illustrate amplitude differences; this
        # factor is heuristic since no direct SU(2) kernel was available.
        gauge_factor = 0.8
    elif gauge_key == "SU3":
        gauge_factor = 1.0
    else:
        raise ValueError(f"Unknown gauge group: {gauge}")

    # Build envelope and normalised gradient magnitude.
    E0, E_smooth, _grad_x, _grad_y, G_hat = envelope.build_envelope(
        L=L,
        ell=ell,
        data_dir=data_dir,
        gauge_factor=gauge_factor,
        decay_alpha=decay_alpha,
    )

    # Solve Poisson equation.
    V, E_x, E_y = poisson.compute_potential(G_hat, lambd, ell)

    # Radial profiles of the potential and of the field magnitude.
    radii_V, V_radial, _ = analysis.radial_profile(V)
    radii_E, E_mag_radial, _ = analysis.radial_profile(np.sqrt(E_x**2 + E_y**2))

    # Select fit windows and fit slopes for potential and field.
    start_V, end_V, slope_V, _, R2_V = analysis.select_fit_window(
        radii_V, V_radial, ell=ell, L=L
    )
    slope_V_err = analysis.bootstrap_slope_error(
        radii_V, V_radial, start_V, end_V, n_resamples=200, seed=seed
    )

    start_E, end_E, slope_E, _, R2_E = analysis.select_fit_window(
        radii_E, E_mag_radial, ell=ell, L=L
    )
    slope_E_err = analysis.bootstrap_slope_error(
        radii_E, E_mag_radial, start_E, end_E, n_resamples=200, seed=seed
    )

    # Gauss‑law plateau over the field fit window.
    flux_mean, flux_std = analysis.gauss_law_plateau(
        radii_E, E_mag_radial, start_E, end_E
    )

    # Amplitude scaling: prefactor V(r)*r at mid‑range.
    amp_pref = analysis.amplitude_at_radius(V_radial, radii_V, start_V, end_V)

    # Optical translation.
    b_vals, alpha_vals, alpha_slope, alpha_R2, b_range = optics.compute_deflection_curve(
        G_hat, lambd, ell, L
    )

    # Control 1: randomly permute envelope values across array positions,
    # keeping their distribution but discarding their spatial arrangement.
    rng = np.random.default_rng(seed)
    E0_shuf_flat = E0.ravel().copy()  # copy so E0 itself is never mutated
    rng.shuffle(E0_shuf_flat)
    E0_shuf = E0_shuf_flat.reshape(E0.shape)
    # Control 2: constant envelope (zero gradient up to smoothing effects).
    # NOTE(review): normalising an all-zero gradient may yield NaN slopes,
    # which json.dump writes as non-standard ``NaN`` — confirm intended.
    E_uniform = np.ones_like(E0, dtype=float)
    # Dict literals evaluate left to right, so the shuffle control runs first.
    controls = {
        "shuffle_slopes": _run_control(E0_shuf, lambd, ell, L),
        "uniform_slopes": _run_control(E_uniform, lambd, ell, L),
    }

    # Hashes for provenance.
    kernel_hash = io_fphs.sha256_of_array(E0)
    pivot_hash = io_fphs.sha256_of_file(os.path.join(data_dir, "pivot_params.json"))
    config = {
        "gauge": gauge,
        "L": L,
        "b": b,
        "kappa": kappa,
        "lambda": lambd,
        "smoothing_ell": ell,
        "gauge_factor": gauge_factor,
        "seed": seed,
        "decay_alpha": decay_alpha,
    }
    config_hash = _compute_config_hash(config)

    # Timezone-aware "now", rendered in the same naive-ISO + "Z" form as before
    # (datetime.utcnow() is deprecated since Python 3.12).
    timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"

    result = {
        "sim": "kernel_to_metric",
        "gauge": gauge,
        "L": L,
        "b": b,
        "kappa": kappa,
        "lambda": lambd,
        "smoothing_ell": ell,
        "translator": {"type": "poisson", "norm": "mean_grad"},
        # NOTE(review): assumes select_fit_window returns an inclusive end
        # index; if it is exclusive, radii_V[end_V] lies one past the window
        # (or out of range) — confirm against analysis.select_fit_window.
        "fit_window": {"r_min": float(radii_V[start_V]), "r_max": float(radii_V[end_V])},
        "radial": {
            "slope_V": float(slope_V),
            "slope_V_err": float(slope_V_err),
            "R2_V": float(R2_V),
            "slope_E": float(slope_E),
            "slope_E_err": float(slope_E_err),
            "R2_E": float(R2_E),
            "flux_plateau": {"mean": float(flux_mean), "std": float(flux_std)},
            "amplitude_prefactor": float(amp_pref),
        },
        "lensing": {
            "alpha_vs_invb_slope": float(alpha_slope),
            "alpha_vs_invb_R2": float(alpha_R2),
            "b_range": [float(b_range[0]), float(b_range[1])],
        },
        "controls": controls,
        "artifacts": {
            "kernel_hash": kernel_hash,
            "pivot_fit_hash": pivot_hash,
            "config_hash": config_hash,
            "seed": seed,
            "timestamp": timestamp,
        },
    }

    # Only gauge, L and ell enter the file name; runs that differ only in
    # other parameters (lambda, b, kappa, seed, decay_alpha) overwrite each other.
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    base_name = f"{gauge}_L{L}_ell{ell}"
    json_path = out_dir / f"{base_name}.json"
    npz_path = out_dir / f"{base_name}.npz"

    np.savez_compressed(
        npz_path,
        E0=E0,
        E_smooth=E_smooth,
        G_hat=G_hat,
        V=V,
        E_x=E_x,
        E_y=E_y,
        V_radial=V_radial,
        E_mag_radial=E_mag_radial,
        radii_V=radii_V,
        radii_E=radii_E,
        b_vals=b_vals,
        alpha_vals=alpha_vals,
    )

    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)

    print(f"Finished {gauge} L={L} ell={ell} → {json_path}")


def _as_list(value) -> list:
    """Return a YAML sweep value as a list.

    A sweep axis may be written as a list (``L: [32, 64]``) or as a single
    scalar (``gauge: SU3``).  Iterating a bare string would walk its
    characters, so scalars are wrapped in a one-element list; ``None`` (key
    absent or explicitly null) becomes an empty list.
    """
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def _git_commit() -> str:
    """Return the git HEAD commit of the checkout containing this module.

    Runs ``git rev-parse HEAD`` from this file's own directory, which git
    resolves to the enclosing repository however deep the package sits.
    Returns ``"unknown"`` if git is unavailable or no commit can be found.
    """
    import subprocess

    try:
        raw = subprocess.check_output(
            ["git", "rev-parse", "HEAD"],
            cwd=Path(__file__).resolve().parent,
            stderr=subprocess.DEVNULL,
        )
    except (OSError, subprocess.SubprocessError):
        # Best effort: missing git binary, not a work tree, or no commits yet.
        return "unknown"
    return raw.decode().strip()


def _write_manifest(data_dir: str, config_path: str, output_dir: str) -> Path:
    """Write ``MANIFEST.json`` into *output_dir* and return its path.

    Records SHA‑256 hashes of ``D_values.csv`` and ``pivot_params.json`` (each
    only if present in *data_dir*), the hash of the YAML config, a UTC
    generation timestamp and the code's git commit.
    """
    manifest = {}
    for key, filename in (
        ("D_values_hash", "D_values.csv"),
        ("pivot_params_hash", "pivot_params.json"),
    ):
        path = os.path.join(data_dir, filename)
        if os.path.exists(path):
            manifest[key] = io_fphs.sha256_of_file(path)
    manifest["anchors_hash"] = io_fphs.sha256_of_file(config_path)
    # Same naive-ISO + "Z" format as datetime.utcnow(), which is deprecated.
    manifest["generated_at"] = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
    manifest["code_commit"] = _git_commit()
    manifest_path = Path(output_dir) / "MANIFEST.json"
    with open(manifest_path, "w", encoding="utf-8") as mf:
        json.dump(manifest, mf, indent=2)
    return manifest_path


def main():
    """Parse CLI arguments, run every configured condition, write the manifest.

    The YAML config's ``gauge``, ``L`` and ``ell`` entries define the sweep
    (each may be a list or a single scalar); every combination is run via
    :func:`run_condition`.  ``b``, ``kappa``, ``lambda`` and ``decay_alpha``
    are converted to float, ``seed`` to int.
    """
    parser = argparse.ArgumentParser(description="Run kernel→metric simulation across multiple conditions.")
    parser.add_argument("--config", type=str, required=True, help="Path to YAML config file")
    parser.add_argument("--data-dir", type=str, default="data", help="Directory containing D_values.csv and pivot_params.json")
    parser.add_argument("--output", type=str, default="runs", help="Directory to write JSON and NPZ outputs")
    args = parser.parse_args()

    os.makedirs(args.output, exist_ok=True)

    with open(args.config, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty document.
        config = yaml.safe_load(f) or {}

    gauges = _as_list(config.get("gauge"))
    L_list = _as_list(config.get("L"))
    ell_list = _as_list(config.get("ell"))
    b = float(config.get("b", 0.0))
    kappa = float(config.get("kappa", 0.0))
    lambd = float(config.get("lambda", 0.0))
    seed = int(config.get("seed", 0))
    # Optional radial decay parameter; controls localisation of envelope.
    # Converted like the other floats: PyYAML reads e.g. ``1e-3`` as a string.
    decay_alpha = config.get("decay_alpha")
    if decay_alpha is not None:
        decay_alpha = float(decay_alpha)

    # Same iteration order as nested gauge → L → ell loops.
    for gauge, L, ell in itertools.product(gauges, L_list, ell_list):
        run_condition(
            gauge=gauge,
            L=L,
            ell=ell,
            lambd=lambd,
            b=b,
            kappa=kappa,
            seed=seed,
            data_dir=args.data_dir,
            output_dir=args.output,
            decay_alpha=decay_alpha,
        )

    manifest_path = _write_manifest(args.data_dir, args.config, args.output)
    print(f"Wrote manifest to {manifest_path}")


if __name__ == "__main__":
    main()
